Skip to content

Conversation

alexbaden
Copy link
Contributor

Introduces a verifier to ensure the DPAS layout attached to a Dot operation has a suitable opsPerChannel param for the A and B operand inputs to the Dot op. Previously this verification was implicit in the Triton GEN verification, producing a somewhat cryptic error message (prior to #4276 there was no error message):

test.mlir:16:11: error: 'triton_gen.dpas' op the dimension for the 2nd operand (A) should be equal to half of the repeat count
    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #3935 

with this new verifier, the error message is more user friendly:

test.mlir:16:11: error: unexpected error: Operand 2 (%0 = "arith.constant"() <{value = dense<0.000000e+00> : tensor<32x32xf32, #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>>}> : () -> tensor<32x32xf32, #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>>) has an invalid layout: #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>.
Layout has opsPerChannel = 2 but tensor element type is 'f32'. Expected 16 bit type.
    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>

close #4270

@chengjunlu
Copy link
Contributor

chengjunlu commented May 29, 2025

Can we move the verifier code to this function which is defined specific for DotOp

verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA,

This interface is used dedicate to verify the DotOp which is more efficient to reduce compiling time.:

LogicalResult DotOp::verify() {
  auto aTy = getA().getType();
  auto bTy = getB().getType();
  if (aTy.getElementType().getIntOrFloatBitWidth() !=
      bTy.getElementType().getIntOrFloatBitWidth())
    return emitError(
        "element types of operands A and B must have same bit width");
  auto aEncoding = aTy.getEncoding();
  auto bEncoding = bTy.getEncoding();
  if (!aEncoding && !bEncoding)
    return success();
  // Verify that the encodings are valid.
  if (!aEncoding || !bEncoding)
    return emitError("mismatching encoding between A and B operands");
  auto accTy = getC().getType();
  auto retEnc = accTy.getEncoding();
  if (!retEnc)
    return emitError("miss encoding of C operand");
  Dialect &dialect = retEnc.getDialect();
  auto interface = cast<DialectInferLayoutInterface>(&dialect);
  return interface->verifyDotOpEncodingCompatibility(getOperation(), aEncoding,
                                                     bEncoding);
}

@@ -957,6 +1018,7 @@ void TritonIntelGPUDialect::initialize() {
>();

addInterfaces<TritonIntelGPUInferLayoutInterface>();
addInterfaces<TritonIntelGPUVerifyTensorLayoutInterface>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an discussion with upstream Triton.
We want the third party dialect can use the TritonGPUVerify Interface as the parent class.
@LiyangLingIntel , Do you know what is the response of the upstream and what is the issue for the discussion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The verify tensor layout interface is called via

Dialect &dialect = layout.getDialect();
    auto verifyLayoutInterface =
        dyn_cast<mlir::triton::DialectVerifyTensorLayoutInterface>(&dialect);
    if (verifyLayoutInterface) {
      return verifyLayoutInterface->verifyTensorLayout(layout, rankedTy, op,
                                                       makeErr);
    }

note that the dialect comes from the layout attribute and not the operation. Why would we need to call the Triton GPU dialect interface / use it as the parent class, when the layouts (attributes) it operates on are not part of our dialect?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know what is the response of the upstream and what is the issue for the discussion.

It seems there is no further design update after that discussion.

Why would we need to call the Triton GPU dialect interface / use it as the parent class, when the layouts (attributes) it operates on are not part of our dialect?

As what I can recall, there were some cases that layouts from Triton GPU dialect would also go into Triton Intel GPU dialect verify/infer layout interface. The reason "use the TritonGPUVerify Interface as the parent class" is to reuse common code to reduce duplication.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is some basic legal check implemented in the Triton GPU dialect interface which is valid for the third_party GPU dialect as well.

But right now those basic check is missed when to check the layout defined in third_party.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know about inferLayoutInterface, but I am curious about verifyLayoutInterface as the dialect comes directly from the layout, and as far as I know there are no layouts shared between dialects.

Now if the layout were a DotOperandEncoding layout with parent from the Intel dialect I could see how that might pose a problem, as DotOperandEncoding would never hit the Intel dialect verifier. But I don't understand how the opposite could be true.

@alexbaden
Copy link
Contributor Author

verifyDotOpEncodingCompatibility may indeed be a better place to put this. When I initially started I was hoping to do the verification on all DPAS layouts attached to any op, but it quickly became clear that, since OpsPerChannel is a property of the operands, we needed to only verify dot op.

As an aside, are we sure OpsPerChannel should be part of the layout? It is implicit in the tensor type of the arguments to the dot operation and could likely be derived given the A and B encodings.

@chengjunlu
Copy link
Contributor

chengjunlu commented May 29, 2025

verifyDotOpEncodingCompatibility may indeed be a better place to put this. When I initially started I was hoping to do the verification on all DPAS layouts attached to any op, but it quickly became clear that, since OpsPerChannel is a property of the operands, we needed to only verify dot op.

As an aside, are we sure OpsPerChannel should be part of the layout? It is implicit in the tensor type of the arguments to the dot operation and could likely be derived given the A and B encodings.

The OpsPerChannel carries the redundant information as the kWidth of the DotOpLayout indeed from IR aspect. To be more precisely and simple, the OpsPerChannel is not useful.

But it gives better experience for reviewing the IR manually from my feeling.

@alexbaden
Copy link
Contributor Author

@chengjunlu does this duplicate #4469 ?

@chengjunlu
Copy link
Contributor

@chengjunlu does this duplicate #4469 ?

I think it is not duplicated.
The interface in #4469 is only to verify the DotOp layout. And it is blocked by upstream Triton review.

The verifier in this PR is special to verify the combination of the scalar type and DotOp layout of the tt.dot operations.

@alexbaden alexbaden force-pushed the alex/verify_dpas_tensor_layout branch from 48f361e to c38da38 Compare June 22, 2025 18:38
@alexbaden
Copy link
Contributor Author

rebased + addressed comments

%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>

// expected-error @+1 {{Layout has opsPerChannel = 2 but tensor element type is 'f32'. Expected 16 bit type.}}
%28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
%28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
%28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf16, #dot_operand_b> -> tensor<32x32xf32, #dpas>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's just keep it the way it was - I am comfortable knowing it tests both permutations and I think the likelihood that they get changed is very low.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't feel strongly about this. Does the test fail with the suggested changes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the original set of suggested changes caused a failure but I did not reproduce locally. It seemed safer and more efficient to use the original test which produces the expected results, and which I verified carefully by disabling validation on A and/or B.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just tried locally, and it works, it likely failed because line 172 is changed but line 176 is not.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's leave it as is - I don't want to have to re-rest and the CI just finished.

@alexbaden alexbaden force-pushed the alex/verify_dpas_tensor_layout branch from e88ea9a to d947c54 Compare June 23, 2025 18:52
@alexbaden alexbaden merged commit 50fc4c3 into main Jun 23, 2025
18 of 24 checks passed
@alexbaden alexbaden deleted the alex/verify_dpas_tensor_layout branch June 23, 2025 21:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Dot op verifier should verify that dpas layout encoding and element type are aligned Investigate failing DPAS validation
6 participants